# coding:utf-8
'''
author:wangyi
'''
import thulac
from gensim.models import KeyedVectors
from gensim.models.word2vec import Word2Vec
import time,datetime


# Convert the raw corpus into a word2vec training corpus (one segmented line per row)
def parse_xml_to_corpus(inf,outf):
    '''
    Segment a raw corpus file into whitespace-separated tokens for word2vec training.

    :param inf: path to the raw corpus (lines may be wrapped in <content>...</content> tags)
    :param outf: path of the segmented training corpus to write
    :return: None
    '''
    # Instantiate the THULAC segmenter (segmentation only, no POS tagging)
    thu = thulac.thulac(seg_only=True)
    # Context managers guarantee both handles are closed even on error
    # (the original opened `out` without `with` and never closed it).
    with open(inf,encoding='utf-8') as f, open(outf,mode='w',encoding='utf-8') as out:
        # Iterate the file lazily instead of readlines() so a large corpus
        # is not loaded into memory all at once.
        for i,line in enumerate(f):
            # Strip the <content> tag pair left over from the raw XML dump
            line = line.strip().replace('<content>','').replace('</content>','')
            # Segment the line; text=True makes cut() return a space-joined string
            line = thu.cut(line,text=True)
            # Write one segmented sentence per line; buffered writes are flushed
            # on close, so no per-line flush() is needed.
            out.write(line+'\n')
            print(i+1,'is writing')


# Train a word2vec model and save the vectors in word2vec text format
def train_w2v(inf,outf,embedding_dim=128,window=5,min_count=5,cbow=0,epoch=5):
    '''
    Train word2vec on a pre-segmented corpus and save the word vectors.

    :param inf: path to the training corpus (one segmented sentence per line)
    :param outf: output path for the word vectors (word2vec text format)
    :param embedding_dim: dimensionality of the word vectors, default 128
    :param window: context window size, default 5
    :param min_count: minimum frequency for a word to enter the vocab, default 5
    :param cbow: 1 for CBOW, 0 for skip-gram; default skip-gram
    :param epoch: number of training iterations, default 5
    :return: None
    '''
    # BUG FIX: the architecture is selected with `sg` (1 = skip-gram, 0 = CBOW),
    # not `cbow_mean` (which only picks mean vs. sum of context vectors *within*
    # CBOW mode). The original passed cbow_mean=cbow, so with the documented
    # default cbow=0 it silently trained CBOW while promising skip-gram.
    model = Word2Vec(corpus_file=inf,size=embedding_dim,window=window,
                     min_count=min_count,sg=0 if cbow else 1,iter=epoch)
    # Save only the word vectors (not the full trainable model)
    model.wv.save_word2vec_format(outf)

# Interactively look up the topN most similar words for a user-supplied query
def get_similar_words(inf,topN=10):
    '''
    Load pre-trained word vectors and answer similarity queries from stdin.

    :param inf: path to the word-vector file (word2vec text format)
    :param topN: number of similar words to print per query, default 10
    :return: None
    '''
    start = time.time()
    # Load the pre-trained vectors; this can take a while for large vocabularies
    model = KeyedVectors.load_word2vec_format(inf)
    print('model loaded success! 用时:',datetime.timedelta(seconds=int(time.time()-start)))
    while True:
        print('请输入查询词汇：')
        # Read the query word from stdin
        word = input()
        # An empty line or whitespace-only input exits the loop
        if not word.strip():
            break
        # Membership test directly on the vocab mapping — no .keys() view needed
        if word in model.vocab:
            for key,cos in model.similar_by_word(word,topn=topN):
                print(key,':',cos)
        else:
            print('the word selected is not in vocab')



if __name__ == '__main__':

    # Step 1: segment the raw corpus into a word2vec training corpus
    #parse_xml_to_corpus('corpus.txt','corpus_train.txt')
    # Step 2: train word2vec on the segmented corpus and dump the vectors
    #train_w2v('corpus_train.txt','w2v.txt',min_count=1,epoch=50)
    # Step 3: interactively query similar words against the trained vectors
    get_similar_words('w2v.txt')





